{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nS7qzvG6-TfI"
      },
      "source": [
        "# Video SSL Tutorial\n",
        "\n",
        "This tutorial trains a video_ssl with 3D-ResNet-50 (R3D-50) as backbone model from the TensorFlow Model Garden package (tensorflow-models).\n",
        "\n",
        "[Model Garden](https://www.tensorflow.org/tfmodels) contains a collection of state-of-the-art models, implemented with TensorFlow's high-level APIs. The implementations demonstrate the best practices for modeling, letting users to take full advantage of TensorFlow for their research and product development.\n",
        "\n",
        "**Dataset:** [UCF_101](https://www.tensorflow.org/datasets/catalog/ucf101)\n",
        "* A 101-label video classification dataset\n",
        "\n",
        "**This tutorial demonstrates how to:**\n",
        "\n",
        "* Use models from the TensorFlow Models package.\n",
        "* Train/Fine-tune a pre-built [video_ssl](https://arxiv.org/abs/2008.03800) for Video Classification.\n",
        "* Export the trained/tuned video_ssl model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HYrARPmK-OZI"
      },
      "source": [
        "## Install cecessary libraries"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XRUierWnXulP"
      },
      "outputs": [],
      "source": [
        "!pip install -U -q \"tensorflow\" \"tensorflow_addons\" \"immutabledict\" \"tensorflow_datasets\"\n",
        "!pip install -U -q remotezip tqdm opencv-python einops\n",
        "!pip install -U -q git+https://github.com/tensorflow/docs"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "K5W_31qC-GgY"
      },
      "source": [
        "## Clone models repository"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WJWNaEAXdgI4"
      },
      "outputs": [],
      "source": [
        "!git clone https://github.com/tensorflow/models.git"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nPbayvv5Pyg-"
      },
      "source": [
        "### Set models as current working dir"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "g8hcYIlicrEn"
      },
      "outputs": [],
      "source": [
        "%cd ./models/"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DnF43pn0-ARp"
      },
      "source": [
        "## Import cecessary libraries"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "or0N8bbktR03"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import tqdm\n",
        "import random\n",
        "import pathlib\n",
        "import pprint\n",
        "import itertools\n",
        "import imageio\n",
        "import collections\n",
        "\n",
        "import cv2\n",
        "import einops\n",
        "import numpy as np\n",
        "import remotezip as rz\n",
        "import seaborn as sns\n",
        "import matplotlib.pyplot as plt\n",
        "import tensorflow as tf\n",
        "\n",
        "from IPython import display\n",
        "from tensorflow_docs.vis import embed\n",
        "\n",
        "pp = pprint.PrettyPrinter(indent=4) # Set Pretty Print Indentation\n",
        "print(tf.__version__) # Check the version of tensorflow used"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "r1hmO0Zxkz_Q"
      },
      "source": [
        "## Import required modules from vidoe_ssl for running the model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EopVEolTkkpd"
      },
      "outputs": [],
      "source": [
        "from official.core import task_factory\n",
        "from official.core import train_lib\n",
        "from official.core import train_utils\n",
        "from official.modeling import performance\n",
        "from official.vision.data import tfrecord_lib\n",
        "\n",
        "from official.projects.video_ssl.configs import video_ssl as exp_cfg\n",
        "from official.projects.video_ssl.modeling import video_ssl_model\n",
        "from official.projects.video_ssl.tasks import linear_eval\n",
        "from official.projects.video_ssl.tasks import pretrain\n",
        "from official.vision import registry_imports\n",
        "from official.vision.serving import export_saved_model_lib"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sIS-qexL968P"
      },
      "source": [
        "## Download pretrained weights"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "10t-SR6XAXgO"
      },
      "outputs": [],
      "source": [
        "!wget https://storage.googleapis.com/tf_model_garden/vision/cvrl/r3d_1x_k600_800ep.tar.gz -P /content/\n",
        "!tar -xvf /content/r3d_1x_k600_800ep.tar.gz -C ../\n",
        "!rm ../r3d_1x_k600_800ep.tar.gz"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xXix2TCo-yAk"
      },
      "source": [
        "## Download Subdataset of UCF_101"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SAE7S0ZCcHOT"
      },
      "outputs": [],
      "source": [
        "# @title Load and preprocess video data\n",
        "def list_files_per_class(zip_url):\n",
        "  \"\"\"\n",
        "    List the files in each class of the dataset given the zip URL.\n",
        "\n",
        "    Args:\n",
        "      zip_url: URL from which the files can be unzipped.\n",
        "\n",
        "    Return:\n",
        "      files: List of files in each of the classes.\n",
        "  \"\"\"\n",
        "  files = []\n",
        "  with rz.RemoteZip(URL) as zip:\n",
        "    for zip_info in zip.infolist():\n",
        "      files.append(zip_info.filename)\n",
        "  return files\n",
        "\n",
        "def get_class(fname):\n",
        "  \"\"\"\n",
        "    Retrieve the name of the class given a filename.\n",
        "\n",
        "    Args:\n",
        "      fname: Name of the file in the UCF101 dataset.\n",
        "\n",
        "    Return:\n",
        "      Class that the file belongs to.\n",
        "  \"\"\"\n",
        "  return fname.split('_')[-3]\n",
        "\n",
        "def get_files_per_class(files):\n",
        "  \"\"\"\n",
        "    Retrieve the files that belong to each class.\n",
        "\n",
        "    Args:\n",
        "      files: List of files in the dataset.\n",
        "\n",
        "    Return:\n",
        "      Dictionary of class names (key) and files (values).\n",
        "  \"\"\"\n",
        "  files_for_class = collections.defaultdict(list)\n",
        "  for fname in files:\n",
        "    class_name = get_class(fname)\n",
        "    files_for_class[class_name].append(fname)\n",
        "  return files_for_class\n",
        "\n",
        "def download_from_zip(zip_url, to_dir, file_names):\n",
        "  \"\"\"\n",
        "    Download the contents of the zip file from the zip URL.\n",
        "\n",
        "    Args:\n",
        "      zip_url: Zip URL containing data.\n",
        "      to_dir: Directory to download data to.\n",
        "      file_names: Names of files to download.\n",
        "  \"\"\"\n",
        "  with rz.RemoteZip(zip_url) as zip:\n",
        "    for fn in tqdm.tqdm(file_names):\n",
        "      class_name = get_class(fn)\n",
        "      zip.extract(fn, str(to_dir / class_name))\n",
        "      unzipped_file = to_dir / class_name / fn\n",
        "\n",
        "      fn = pathlib.Path(fn).parts[-1]\n",
        "      output_file = to_dir / class_name / fn\n",
        "      unzipped_file.rename(output_file,)\n",
        "\n",
        "def split_class_lists(files_for_class, count):\n",
        "  \"\"\"\n",
        "    Returns the list of files belonging to a subset of data as well as the remainder of\n",
        "    files that need to be downloaded.\n",
        "\n",
        "    Args:\n",
        "      files_for_class: Files belonging to a particular class of data.\n",
        "      count: Number of files to download.\n",
        "\n",
        "    Return:\n",
        "      split_files: Files belonging to the subset of data.\n",
        "      remainder: Dictionary of the remainder of files that need to be downloaded.\n",
        "  \"\"\"\n",
        "  split_files = []\n",
        "  remainder = {}\n",
        "  for cls in files_for_class:\n",
        "    split_files.extend(files_for_class[cls][:count])\n",
        "    remainder[cls] = files_for_class[cls][count:]\n",
        "  return split_files, remainder\n",
        "\n",
        "def download_ufc_101_subset(zip_url, num_classes, splits, download_dir):\n",
        "  \"\"\"\n",
        "    Download a subset of the UFC101 dataset and split them into various parts, such as\n",
        "    training, validation, and test.\n",
        "\n",
        "    Args:\n",
        "      zip_url: Zip URL containing data.\n",
        "      num_classes: Number of labels.\n",
        "      splits: Dictionary specifying the training, validation, test, etc. (key) division of data\n",
        "              (value is number of files per split).\n",
        "      download_dir: Directory to download data to.\n",
        "\n",
        "    Return:\n",
        "      dir: Posix path of the resulting directories containing the splits of data.\n",
        "  \"\"\"\n",
        "  files = list_files_per_class(zip_url)\n",
        "  for f in files:\n",
        "    tokens = f.split('/')\n",
        "    if len(tokens) \u003c= 2:\n",
        "      files.remove(f) # Remove that item from the list if it does not have a filename\n",
        "\n",
        "  files_for_class = get_files_per_class(files)\n",
        "\n",
        "  classes = list(files_for_class.keys())[:num_classes]\n",
        "\n",
        "  for cls in classes:\n",
        "    new_files_for_class = files_for_class[cls]\n",
        "    random.shuffle(new_files_for_class)\n",
        "    files_for_class[cls] = new_files_for_class\n",
        "\n",
        "  # Only use the number of classes you want in the dictionary\n",
        "  files_for_class = {x: files_for_class[x] for x in list(files_for_class)[:num_classes]}\n",
        "\n",
        "  dirs = {}\n",
        "  for split_name, split_count in splits.items():\n",
        "    print(split_name, \":\")\n",
        "    split_dir = download_dir / split_name\n",
        "    split_files, files_for_class = split_class_lists(files_for_class, split_count)\n",
        "    download_from_zip(zip_url, split_dir, split_files)\n",
        "    dirs[split_name] = split_dir\n",
        "\n",
        "  return dirs\n",
        "\n",
        "def format_frames(frame, output_size):\n",
        "  \"\"\"\n",
        "    Pad and resize an image from a video.\n",
        "\n",
        "    Args:\n",
        "      frame: Image that needs to resized and padded.\n",
        "      output_size: Pixel size of the output frame image.\n",
        "\n",
        "    Return:\n",
        "      Formatted frame with padding of specified output size.\n",
        "  \"\"\"\n",
        "  frame = tf.image.convert_image_dtype(frame, tf.uint8)\n",
        "  frame = tf.image.resize_with_pad(frame, *output_size)\n",
        "  return frame\n",
        "\n",
        "def frames_from_video_file(video_path, n_frames, output_size = (224,224), frame_step = 15):\n",
        "  \"\"\"\n",
        "    Creates frames from each video file present for each category.\n",
        "\n",
        "    Args:\n",
        "      video_path: File path to the video.\n",
        "      n_frames: Number of frames to be created per video file.\n",
        "      output_size: Pixel size of the output frame image.\n",
        "\n",
        "    Return:\n",
        "      An NumPy array of frames in the shape of (n_frames, height, width, channels).\n",
        "  \"\"\"\n",
        "  # Read each video frame by frame\n",
        "  result = []\n",
        "  src = cv2.VideoCapture(str(video_path))\n",
        "\n",
        "  video_length = src.get(cv2.CAP_PROP_FRAME_COUNT)\n",
        "\n",
        "  need_length = 1 + (n_frames - 1) * frame_step\n",
        "\n",
        "  if need_length \u003e video_length:\n",
        "    start = 0\n",
        "  else:\n",
        "    max_start = video_length - need_length\n",
        "    start = random.randint(0, max_start + 1)\n",
        "\n",
        "  src.set(cv2.CAP_PROP_POS_FRAMES, start)\n",
        "  # ret is a boolean indicating whether read was successful, frame is the image itself\n",
        "  ret, frame = src.read()\n",
        "  result.append(format_frames(frame, output_size))\n",
        "\n",
        "  for _ in range(n_frames - 1):\n",
        "    for _ in range(frame_step):\n",
        "      ret, frame = src.read()\n",
        "    if ret:\n",
        "      frame = format_frames(frame, output_size)\n",
        "      result.append(frame)\n",
        "    else:\n",
        "      result.append(np.zeros_like(result[0]))\n",
        "  src.release()\n",
        "  result = np.array(result)[..., [2, 1, 0]]\n",
        "\n",
        "  return result\n",
        "\n",
        "def to_gif(images):\n",
        "  imageio.mimsave('./animation.gif', images, fps=10)\n",
        "  return embed.embed_file('./animation.gif')\n",
        "\n",
        "class FrameGenerator:\n",
        "  def __init__(self, path, n_frames, training = False):\n",
        "    \"\"\" Returns a set of frames with their associated label.\n",
        "\n",
        "      Args:\n",
        "        path: Video file paths.\n",
        "        n_frames: Number of frames.\n",
        "        training: Boolean to determine if training dataset is being created.\n",
        "    \"\"\"\n",
        "    self.path = path\n",
        "    self.n_frames = n_frames\n",
        "    self.training = training\n",
        "    self.class_names = sorted(set(p.name for p in self.path.iterdir() if p.is_dir()))\n",
        "    self.class_ids_for_name = dict((name, idx) for idx, name in enumerate(self.class_names))\n",
        "\n",
        "  def get_files_and_class_names(self):\n",
        "    video_paths = list(self.path.glob('*/*.avi'))\n",
        "    classes = [p.parent.name for p in video_paths]\n",
        "    return video_paths, classes\n",
        "\n",
        "  def __call__(self):\n",
        "    video_paths, classes = self.get_files_and_class_names()\n",
        "\n",
        "    pairs = list(zip(video_paths, classes))\n",
        "\n",
        "    if self.training:\n",
        "      random.shuffle(pairs)\n",
        "\n",
        "    for path, name in pairs:\n",
        "      video_frames = frames_from_video_file(path, self.n_frames)\n",
        "      label = self.class_ids_for_name[name] # Encode labels\n",
        "      yield video_frames, label"
      ]
    },
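    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The frame-sampling logic in `frames_from_video_file` can be checked with plain arithmetic. The sketch below assumes the tutorial's defaults (`n_frames = 10`, `frame_step = 15`) and uses a hypothetical helper name, `sampling_window`, that is not part of this notebook or of any library. It computes how many source frames one clip spans and the largest start offset that keeps the whole clip inside the video."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def sampling_window(video_length, n_frames=10, frame_step=15):\n",
        "  # Hypothetical helper: mirrors the arithmetic in frames_from_video_file.\n",
        "  # One clip spans the first frame plus (n_frames - 1) strides of frame_step.\n",
        "  need_length = 1 + (n_frames - 1) * frame_step\n",
        "  # Largest start index that keeps the last sampled frame inside the video.\n",
        "  # For videos shorter than need_length, sampling starts at frame 0.\n",
        "  max_start = max(video_length - need_length, 0)\n",
        "  return need_length, max_start\n",
        "\n",
        "need_length, max_start = sampling_window(video_length=300)\n",
        "print(need_length, max_start)  # 136 164"
      ]
    },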
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nnn5gt5rKyqu"
      },
      "outputs": [],
      "source": [
        "# Helper functions below used are taken from following tutorials\n",
        "# https://www.tensorflow.org/tutorials/video/video_classification\n",
        "# https://www.tensorflow.org/tutorials/video/transfer_learning_with_movinet\n",
        "\n",
        "URL = 'https://storage.googleapis.com/thumos14_files/UCF101_videos.zip'\n",
        "download_dir = pathlib.Path('../UCF101_subset/')\n",
        "subset_paths = download_ufc_101_subset(URL,\n",
        "                        num_classes = 10,\n",
        "                        splits = {\"train\": 40, \"val\": 10, \"test\": 10},\n",
        "                        download_dir = download_dir)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0ah-b7zZ-1zr"
      },
      "source": [
        "## Prepare train, valid and test dataset"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Bvnz1T_CgxuT"
      },
      "outputs": [],
      "source": [
        "n_frames = 10\n",
        "CLASSES = sorted(os.listdir('../UCF101_subset/train/'))\n",
        "\n",
        "output_signature = (tf.TensorSpec(shape = (None, None, None, 3), dtype = tf.uint8, name='image'),\n",
        "                    tf.TensorSpec(shape = (), dtype = tf.int16, name='label'))\n",
        "\n",
        "train_ds = tf.data.Dataset.from_generator(FrameGenerator(subset_paths['train'], n_frames, training=True),\n",
        "                                          output_signature = output_signature)\n",
        "\n",
        "val_ds = tf.data.Dataset.from_generator(FrameGenerator(subset_paths['val'], n_frames),\n",
        "                                        output_signature = output_signature)\n",
        "\n",
        "test_ds = tf.data.Dataset.from_generator(FrameGenerator(subset_paths['test'], n_frames),\n",
        "                                         output_signature = output_signature)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uRjtR3N__WDi"
      },
      "source": [
        "## Write data as TFRecords"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9FeJjc7J_erR"
      },
      "source": [
        "### Helper function to convert data as TF Sequence Example"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "thYjfktWj5Bh"
      },
      "outputs": [],
      "source": [
        "def process_record(record):\n",
        "  \"\"\"\n",
        "    Convert training samples to SequenceExample format. For more detailed\n",
        "    explaination about SequenceExample, please check here\n",
        "    https://www.tensorflow.org/api_docs/python/tf/train/SequenceExample\n",
        "\n",
        "    Args:\n",
        "      record: training example with image frames and corresponding label.\n",
        "\n",
        "    Return:\n",
        "      Return a SequenceExample which represents a\n",
        "      sequence of features and some context.\n",
        "  \"\"\"\n",
        "  seq_example = tf.train.SequenceExample()\n",
        "  for example in record[0]:\n",
        "    seq_example.feature_lists.feature_list.get_or_create(\n",
        "        'image/encoded').feature.add().bytes_list.value[:] = [\n",
        "                tf.io.encode_jpeg(example).numpy()\n",
        "                ]\n",
        "  seq_example.context.feature[\n",
        "      'clip/label/index'].int64_list.value[:] = [record[1].numpy()]\n",
        "  return seq_example"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "i-9fsPOQp9Qy"
      },
      "outputs": [],
      "source": [
        "output_dir = '../ucf101_tfrecords/'\n",
        "LOG_EVERY = 100\n",
        "if not os.path.exists(output_dir):\n",
        "  os.mkdir(output_dir)\n",
        "\n",
        "\n",
        "def write_tfrecords(dataset, output_path, num_shards=1):\n",
        "  \"\"\"\n",
        "      Convert training samples to tfrecords\n",
        "\n",
        "      Args:\n",
        "        dataset: Dataset as a iterator in (tfds format).\n",
        "        output_path: Directory to store the tfrecords.\n",
        "        num_shards: Split the tfrecords to sepecific number of shards.\n",
        "  \"\"\"\n",
        "\n",
        "  writers = [\n",
        "        tf.io.TFRecordWriter(\n",
        "            output_path + '-%05d-of-%05d.tfrecord' % (i, num_shards))\n",
        "        for i in range(num_shards)\n",
        "    ]\n",
        "  for idx, record in enumerate(dataset):\n",
        "    if idx % LOG_EVERY == 0:\n",
        "      print('On image %d', idx)\n",
        "    seq_example = process_record(record)\n",
        "    writers[idx % num_shards].write(seq_example.SerializeToString())"
      ]
    },
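    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "`write_tfrecords` assigns records to shards round-robin and names each shard with a `-%05d-of-%05d.tfrecord` suffix, which is why the experiment configuration can match the files with a glob such as `train*`. A minimal sketch of that naming and assignment scheme follows. The helper names `shard_paths` and `shard_for_record` are hypothetical and exist only for illustration; the sketch needs no TensorFlow."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def shard_paths(output_path, num_shards):\n",
        "  # Hypothetical helper: file names produced for a given output prefix.\n",
        "  return [output_path + '-%05d-of-%05d.tfrecord' % (i, num_shards)\n",
        "          for i in range(num_shards)]\n",
        "\n",
        "def shard_for_record(idx, num_shards):\n",
        "  # Round-robin assignment: record idx goes to shard idx % num_shards.\n",
        "  return idx % num_shards\n",
        "\n",
        "paths = shard_paths('../ucf101_tfrecords/train', num_shards=10)\n",
        "print(paths[0])   # ../ucf101_tfrecords/train-00000-of-00010.tfrecord\n",
        "print(paths[-1])  # ../ucf101_tfrecords/train-00009-of-00010.tfrecord\n",
        "print(shard_for_record(23, num_shards=10))  # 3"
      ]
    },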
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Z7gfIpOa_oIj"
      },
      "source": [
        "### Write training data as TFRecords"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "C2BygDlYnFi8"
      },
      "outputs": [],
      "source": [
        "output_train_tfrecs = output_dir + 'train'\n",
        "write_tfrecords(train_ds, output_train_tfrecs, num_shards=10)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "daNmrC7T_t5i"
      },
      "source": [
        "### Write validation data as TFRecords"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bxMiQmM8GfQG"
      },
      "outputs": [],
      "source": [
        "output_val_tfrecs = output_dir + 'valid'\n",
        "write_tfrecords(val_ds, output_val_tfrecs, num_shards=5)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5pTf5iv__voH"
      },
      "source": [
        "### Write test data as TFRecords"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BRQad9B3GkHQ"
      },
      "outputs": [],
      "source": [
        "output_test_tfrecs = output_dir + 'test'\n",
        "write_tfrecords(test_ds, output_test_tfrecs, num_shards=5)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tHm2vpHy_3w_"
      },
      "source": [
        "## Experiment Configuration"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ffEqt4bD_-bQ"
      },
      "source": [
        "### Load the existing Configuration"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tf4N6_uW4UgK"
      },
      "outputs": [],
      "source": [
        "import yaml\n",
        "\n",
        "with open('./official/projects/video_ssl/configs/experiments/cvrl_linear_eval_k600.yaml', 'r') as file:\n",
        "  override_params = yaml.full_load(file)\n",
        "\n",
        "\n",
        "exp_config = exp_cfg.exp_factory.get_exp_config('video_ssl_linear_eval_kinetics600')\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "V9QDyhgMAHs2"
      },
      "source": [
        "### Override the configuration parameters"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HYDhSyC4AQ9c"
      },
      "outputs": [],
      "source": [
        "exp_config.override(override_params, is_strict=False)\n",
        "\n",
        "WIDTH, HEIGHT = 224, 224\n",
        "\n",
        "# Runtime configuration\n",
        "exp_config.runtime.distribution_strategy = \"mirrored\"\n",
        "\n",
        "# Task configuration\n",
        "exp_config.task.freeze_backbone = True\n",
        "exp_config.task.init_checkpoint = \"../r3d_1x_k600_800ep/r3d_1x_k600_800ep_backbone-1\"\n",
        "exp_config.task.init_checkpoint_modules = \"backbone\"\n",
        "\n",
        "# Model configuration\n",
        "exp_config.task.model.projection_dim = 10\n",
        "\n",
        "# Training data configuration\n",
        "exp_config.task.train_data.input_path = '../ucf101_tfrecords/train*'\n",
        "exp_config.task.train_data.num_classes=10\n",
        "exp_config.task.train_data.global_batch_size = 2\n",
        "exp_config.task.train_data.min_image_size = WIDTH\n",
        "exp_config.task.train_data.num_examples = 400\n",
        "exp_config.task.train_data.feature_shape = (n_frames, HEIGHT, WIDTH, 3)\n",
        "\n",
        "# Validation data configuration\n",
        "exp_config.task.validation_data.num_classes=10\n",
        "exp_config.task.validation_data.input_path = '../ucf101_tfrecords/valid*'\n",
        "exp_config.task.validation_data.global_batch_size = 2\n",
        "exp_config.task.validation_data.min_image_size = WIDTH\n",
        "exp_config.task.validation_data.num_examples = 100\n",
        "exp_config.task.validation_data.feature_shape = (n_frames, HEIGHT, WIDTH, 3)\n",
        "\n",
        "# Trainer configuration\n",
        "\n",
        "exp_config.trainer.train_steps = 2000\n",
        "exp_config.trainer.checkpoint_interval = 200\n",
        "exp_config.trainer.steps_per_loop = 200\n",
        "exp_config.trainer.summary_interval = 200\n",
        "exp_config.trainer.validation_interval = 200\n",
        "exp_config.trainer.validation_steps = 200\n",
        "exp_config.trainer.optimizer_config.learning_rate.cosine.decay_steps = 2000\n",
        "exp_config.trainer.optimizer_config.learning_rate.cosine.initial_learning_rate = 0.008\n",
        "exp_config.trainer.optimizer_config.warmup.linear.warmup_learning_rate = 0.007\n",
        "exp_config.trainer.optimizer_config.warmup.linear.warmup_steps = 200"
      ]
    },
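    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The trainer settings are easier to interpret when expressed in epochs. The sketch below is plain arithmetic on the values set in the preceding cell; the variable names are illustrative and are not part of the Model Garden API. It assumes that one training step consumes exactly one global batch."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Values copied from the configuration overrides (illustrative names).\n",
        "train_examples = 400        # exp_config.task.train_data.num_examples\n",
        "train_batch_size = 2        # exp_config.task.train_data.global_batch_size\n",
        "total_train_steps = 2000    # exp_config.trainer.train_steps\n",
        "eval_every_n_steps = 200    # exp_config.trainer.validation_interval\n",
        "\n",
        "# Assumption: one step processes one global batch.\n",
        "steps_per_epoch = train_examples // train_batch_size\n",
        "print(steps_per_epoch)                          # 200\n",
        "print(total_train_steps / steps_per_epoch)      # 10.0\n",
        "print(total_train_steps // eval_every_n_steps)  # 10"
      ]
    },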
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rmxwvgimiSQu"
      },
      "source": [
        "### Set up the distribution strategy"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RpD__dO5xJUC"
      },
      "outputs": [],
      "source": [
        "# Detect hardware\n",
        "try:\n",
        "  tpu_resolver = tf.distribute.cluster_resolver.TPUClusterResolver() # TPU detection\n",
        "except ValueError:\n",
        "  tpu_resolver = None\n",
        "  gpus = tf.config.experimental.list_logical_devices(\"GPU\")\n",
        "\n",
        "# Select appropriate distribution strategy\n",
        "if tpu_resolver:\n",
        "  tf.config.experimental_connect_to_cluster(tpu_resolver)\n",
        "  tf.tpu.experimental.initialize_tpu_system(tpu_resolver)\n",
        "  distribution_strategy = tf.distribute.experimental.TPUStrategy(tpu_resolver)\n",
        "  print('Running on TPU ', tpu_resolver.cluster_spec().as_dict()['worker'])\n",
        "elif len(gpus) \u003e 1:\n",
        "  distribution_strategy = tf.distribute.MirroredStrategy([gpu.name for gpu in gpus])\n",
        "  print('Running on multiple GPUs ', [gpu.name for gpu in gpus])\n",
        "elif len(gpus) == 1:\n",
        "  distribution_strategy = tf.distribute.get_strategy() # default strategy that works on CPU and single GPU\n",
        "  print('Running on single GPU ', gpus[0].name)\n",
        "else:\n",
        "  distribution_strategy = tf.distribute.get_strategy() # default strategy that works on CPU and single GPU\n",
        "  print('Running on CPU')\n",
        "\n",
        "print(\"Number of accelerators: \", distribution_strategy.num_replicas_in_sync)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iIT9AgxoAday"
      },
      "source": [
        "### Display the final configuration"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_GVF3uDw8QVt"
      },
      "outputs": [],
      "source": [
        "pp.pprint(exp_config.as_dict())\n",
        "display.Javascript('google.colab.output.setIframeHeight(\"500px\");')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VZF9G3M5i1Lm"
      },
      "source": [
        "## Create the `Task` object (`tfm.core.base_task.Task`) from the `config_definitions.TaskConfig`.\n",
        "\n",
        "The `Task` object has all the methods necessary for building the dataset, building the model, and running training \u0026 evaluation. These methods are driven by `tfm.core.train_lib.run_experiment`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OPn2AMRZK-zZ"
      },
      "outputs": [],
      "source": [
        "model_dir = '../trained_model/'\n",
        "\n",
        "with distribution_strategy.scope():\n",
        "  task = task_factory.get_task(exp_config.task, logging_dir=model_dir)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6M0rlFn-igH7"
      },
      "source": [
        "## Visualization of Train Data"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CGxZEl6ijF2a"
      },
      "outputs": [],
      "source": [
        "frames, label = list(train_ds.take(1))[0]\n",
        "print(CLASSES[label])\n",
        "to_gif(frames)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bnMqrCtgo4UO"
      },
      "source": [
        "## Train and evaluate"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "watwZ5mIwtvk"
      },
      "outputs": [],
      "source": [
        "model, eval_logs = train_lib.run_experiment(\n",
        "    distribution_strategy=distribution_strategy,\n",
        "    task=task,\n",
        "    mode='train_and_eval',\n",
        "    params=exp_config,\n",
        "    model_dir=model_dir)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZJBzAMJeo9JN"
      },
      "source": [
        "## Load logs in tensorboard"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "c3smE05dIgAV"
      },
      "outputs": [],
      "source": [
        "%load_ext tensorboard\n",
        "%tensorboard --logdir '../trained_model'"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WSTLebJGo_8a"
      },
      "source": [
        "## Saving and exporting the trained model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "o4lhmgy9I6Du"
      },
      "outputs": [],
      "source": [
        "export_dir = '../exported_model/'\n",
        "\n",
        "export_saved_model_lib.export_inference_graph(\n",
        "    input_type='image_tensor',\n",
        "    batch_size=1,\n",
        "    input_image_size=[n_frames, HEIGHT, WIDTH],\n",
        "    params=exp_config,\n",
        "    checkpoint_path=tf.train.latest_checkpoint(model_dir),\n",
        "    export_dir=export_dir)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "10Yn7-pBpns-"
      },
      "source": [
        "## Importing SavedModel"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "u3fghz-jVrcR"
      },
      "outputs": [],
      "source": [
        "imported = tf.saved_model.load(export_dir)\n",
        "model_fn = imported.signatures['serving_default']"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "D6qzUv1lpsRL"
      },
      "source": [
        "## Visualize predictions"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1McexxfTh-Bl"
      },
      "outputs": [],
      "source": [
        "frames, label = list(test_ds.shuffle(buffer_size=90).take(1))[0]\n",
        "frames = tf.expand_dims(frames, axis=0)\n",
        "result = model_fn(frames)\n",
        "predicted_label = tf.argmax(result['probs'][0])\n",
        "print(f\"Actual: {CLASSES[label]}\")\n",
        "print(f\"Predicted: {CLASSES[predicted_label]}\")\n",
        "to_gif(frames[0])"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "name": "video_ssl.ipynb",
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
